{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5rmpybwysXGV"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "m8y3rGtQsYP2"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hrXv0rU9sIma"
      },
      "source": [
        "# TensorFlow basics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7S0BwJ_8sLu7"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/basics\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/basics.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/basics.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/basics.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iJyZUDbzBTIG"
      },
      "source": [
        "This guide provides a quick overview of _TensorFlow basics_. Each section of this doc is an overview of a larger topic, and links to the full guides appear at the end of each section.\n",
        "\n",
        "TensorFlow is an end-to-end platform for machine learning. It supports the following:\n",
        "\n",
        "* Multidimensional-array based numeric computation (similar to <a href=\"https://numpy.org/\" class=\"external\">NumPy</a>)\n",
        "* GPU and distributed processing\n",
        "* Automatic differentiation\n",
        "* Model construction, training, and export\n",
        "* And more"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gvLegMMvBZYg"
      },
      "source": [
        "## Tensors\n",
        "\n",
        "TensorFlow operates on multidimensional arrays or _tensors_ represented as `tf.Tensor` objects. Here is a two-dimensional tensor:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ZqX5RnbBS1f"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "x = tf.constant([[1., 2., 3.],\n",
        "                 [4., 5., 6.]])\n",
        "\n",
        "print(x)\n",
        "print(x.shape)\n",
        "print(x.dtype)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k-AOMqevQGN4"
      },
      "source": [
        "The most important attributes of a `tf.Tensor` are its `shape` and `dtype`:\n",
        "\n",
        "* `Tensor.shape`: tells you the size of the tensor along each of its axes.\n",
        "* `Tensor.dtype`: tells you the type of all the elements in the tensor."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bUkKeNWZCIJO"
      },
      "source": [
        "TensorFlow implements standard mathematical operations on tensors, as well as many operations specialized for machine learning.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BM7xXNDsBfN5"
      },
      "outputs": [],
      "source": [
        "x + x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZLGqscTxB61v"
      },
      "outputs": [],
      "source": [
        "5 * x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ImJHd8VfnWq"
      },
      "outputs": [],
      "source": [
        "x @ tf.transpose(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U9JZD6TYCZWu"
      },
      "outputs": [],
      "source": [
        "tf.concat([x, x, x], axis=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "seGBLeD9P_PI"
      },
      "outputs": [],
      "source": [
        "tf.nn.softmax(x, axis=-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YZNZRv1ECjf8"
      },
      "outputs": [],
      "source": [
        "tf.reduce_sum(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TNHnIjOVLJfA"
      },
      "source": [
        "Note: Typically, anywhere a TensorFlow function expects a `Tensor` as input, the function will also accept anything that can be converted to a `Tensor` using `tf.convert_to_tensor`. See below for an example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i_XKgjDsL4GE"
      },
      "outputs": [],
      "source": [
        "tf.convert_to_tensor([1,2,3])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wTBt-JUqLJDJ"
      },
      "outputs": [],
      "source": [
        "tf.reduce_sum([1,2,3])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8-mi5031DVxz"
      },
      "source": [
        "Running large calculations on CPU can be slow. When properly configured, TensorFlow can use accelerator hardware like GPUs to execute operations very quickly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m97Gv5H6Dz0G"
      },
      "outputs": [],
      "source": [
        "if tf.config.list_physical_devices('GPU'):\n",
        "  print(\"TensorFlow **IS** using the GPU\")\n",
        "else:\n",
        "  print(\"TensorFlow **IS NOT** using the GPU\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ln2FkLOqMX92"
      },
      "source": [
        "Refer to the [Tensor guide](tensor.ipynb) for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oVbomvMyEIVF"
      },
      "source": [
        "## Variables\n",
        "\n",
        "Normal `tf.Tensor` objects are immutable. To store model weights (or other mutable state) in TensorFlow, use a `tf.Variable`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SO8_bP4UEzxS"
      },
      "outputs": [],
      "source": [
        "var = tf.Variable([0.0, 0.0, 0.0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aDLYFvu5FAFa"
      },
      "outputs": [],
      "source": [
        "var.assign([1, 2, 3])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9EpiOmxXFDSS"
      },
      "outputs": [],
      "source": [
        "var.assign_add([1, 1, 1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tlvTpi1CMedC"
      },
      "source": [
        "Refer to the [Variables guide](variable.ipynb) for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rG1Dhv2QFkV3"
      },
      "source": [
        "## Automatic differentiation\n",
        "\n",
        "<a href=\"https://en.wikipedia.org/wiki/Gradient_descent\" class=\"external\">_Gradient descent_</a> and related algorithms are a cornerstone of modern machine learning.\n",
        "\n",
        "To enable this, TensorFlow implements automatic differentiation (autodiff), which uses calculus to compute gradients. Typically you'll use this to calculate the gradient of a model's _error_ or _loss_ with respect to its weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cYKOi-z4GY9Y"
      },
      "outputs": [],
      "source": [
        "x = tf.Variable(1.0)\n",
        "\n",
        "def f(x):\n",
        "  y = x**2 + 2*x - 5\n",
        "  return y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IQz99cxMGoF_"
      },
      "outputs": [],
      "source": [
        "f(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozLLop0cHeYl"
      },
      "source": [
        "At `x = 1.0`, `y = f(x) = (1**2 + 2*1 - 5) = -2`.\n",
        "\n",
        "The derivative of `y` is `y' = f'(x) = (2*x + 2) = 4`. TensorFlow can calculate this automatically:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N02NfWpHGvw8"
      },
      "outputs": [],
      "source": [
        "with tf.GradientTape() as tape:\n",
        "  y = f(x)\n",
        "\n",
        "g_x = tape.gradient(y, x)  # g(x) = dy/dx\n",
        "\n",
        "g_x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s-DVYJfcIRPd"
      },
      "source": [
        "This simplified example only takes the derivative with respect to a single scalar (`x`), but TensorFlow can compute the gradient with respect to any number of non-scalar tensors simultaneously."
      ]
    },
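    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the multi-tensor case, the cell below takes the gradient of a scalar loss with respect to two non-scalar variables in a single call. The names `w`, `b`, `inputs`, and `loss` are illustrative assumptions and are not used elsewhere in this guide:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Illustrative (hypothetical) parameters: a 3x2 weight matrix and a length-2 bias.\n",
        "w = tf.Variable(tf.ones([3, 2]))\n",
        "b = tf.Variable(tf.zeros([2]))\n",
        "inputs = tf.constant([[1., 2., 3.]])\n",
        "\n",
        "with tf.GradientTape() as tape:\n",
        "  # Trainable variables are watched by the tape automatically.\n",
        "  loss = tf.reduce_sum(inputs @ w + b)\n",
        "\n",
        "# Passing a list of sources returns a list of gradients, one per source,\n",
        "# each with the same shape as the corresponding variable.\n",
        "grad_w, grad_b = tape.gradient(loss, [w, b])\n",
        "\n",
        "print(grad_w.shape)\n",
        "print(grad_b.shape)"
      ]
    },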
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ECK3I9bUMk_r"
      },
      "source": [
        "Refer to the [Autodiff guide](autodiff.ipynb) for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VglUM4M3KhNz"
      },
      "source": [
        "## Graphs and tf.function\n",
        "\n",
        "While you can use TensorFlow interactively like any Python library, TensorFlow also provides tools for:\n",
        "\n",
        "* **Performance optimization**: to speed up training and inference.\n",
        "* **Export**: so you can save your model when it's done training.\n",
        "\n",
        "These require that you use `tf.function` to separate your pure-TensorFlow code from Python."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VitACyZWKJD_"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def my_func(x):\n",
        "  print('Tracing.\\n')\n",
        "  return tf.reduce_sum(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fBYDh-huNUBZ"
      },
      "source": [
        "The first time you run the `tf.function`, although it executes in Python, it captures a complete, optimized graph representing the TensorFlow computations done within the function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vkOFSEkoM1bd"
      },
      "outputs": [],
      "source": [
        "x = tf.constant([1, 2, 3])\n",
        "my_func(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a3aWzt-rNsBa"
      },
      "source": [
        "On subsequent calls TensorFlow only executes the optimized graph, skipping any non-TensorFlow steps. Below, note that `my_func` doesn't print _tracing_ since `print` is a Python function, not a TensorFlow function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "23dMHWwwNIoa"
      },
      "outputs": [],
      "source": [
        "x = tf.constant([10, 9, 8])\n",
        "my_func(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nSeTti6zki0n"
      },
      "source": [
        "A graph may not be reusable for inputs with a different _signature_ (`shape` and `dtype`), so a new graph is generated instead:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OWffqyhqlVPf"
      },
      "outputs": [],
      "source": [
        "x = tf.constant([10.0, 9.1, 8.2], dtype=tf.float32)\n",
        "my_func(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UWknAA_zNTOa"
      },
      "source": [
        "These captured graphs provide two benefits:\n",
        "\n",
        "* In many cases they provide a significant speedup in execution (though not this trivial example).\n",
        "* You can export these graphs, using `tf.saved_model`, to run on other systems like a [server](https://www.tensorflow.org/tfx/serving/docker) or a [mobile device](https://www.tensorflow.org/lite/guide), no Python installation required."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hLUJ6f2eMsA8"
      },
      "source": [
        "Refer to [Intro to graphs](intro_to_graphs.ipynb) for more details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t_36xPDPPBqp"
      },
      "source": [
        "## Modules, layers, and models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oDaT7kCpUgnJ"
      },
      "source": [
        "`tf.Module` is a class for managing your `tf.Variable` objects, and the `tf.function` objects that operate on them. The `tf.Module` class is necessary to support two significant features:\n",
        "\n",
        "1. You can save and restore the values of your variables using `tf.train.Checkpoint`. This is useful during training as it is quick to save and restore a model's state.\n",
        "2. You can import and export the `tf.Variable` values _and_ the `tf.function` graphs using `tf.saved_model`. This allows you to run your model independently of the Python program that created it.\n",
        "\n",
        "Here is a complete example exporting a simple `tf.Module` object:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1MqEcZOqPBDV"
      },
      "outputs": [],
      "source": [
        "class MyModule(tf.Module):\n",
        "  def __init__(self, value):\n",
        "    self.weight = tf.Variable(value)\n",
        "\n",
        "  @tf.function\n",
        "  def multiply(self, x):\n",
        "    return x * self.weight"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "la2G82HfVfU0"
      },
      "outputs": [],
      "source": [
        "mod = MyModule(3)\n",
        "mod.multiply(tf.constant([1, 2, 3]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GaSJX7zQXCm4"
      },
      "source": [
        "Save the `Module`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1MlfbEMjVzG4"
      },
      "outputs": [],
      "source": [
        "save_path = './saved'\n",
        "tf.saved_model.save(mod, save_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LgfoftD4XGJW"
      },
      "source": [
        "The resulting SavedModel is independent of the code that created it. You can load a SavedModel from Python, other language bindings, or [TensorFlow Serving](https://www.tensorflow.org/tfx/serving/docker). You can also convert it to run with [TensorFlow Lite](https://www.tensorflow.org/lite/guide) or [TensorFlow JS](https://www.tensorflow.org/js/guide)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pWuLOIKBWZYG"
      },
      "outputs": [],
      "source": [
        "reloaded = tf.saved_model.load(save_path)\n",
        "reloaded.multiply(tf.constant([1, 2, 3]))"
      ]
    },
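    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cells above exercise `tf.saved_model`. For the other feature, `tf.train.Checkpoint`, a minimal sketch might look like the following. The `Scale` class, its `factor` attribute, and the temporary directory are hypothetical choices made only for illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "# A small stand-in module; `Scale` and `factor` are hypothetical names.\n",
        "class Scale(tf.Module):\n",
        "  def __init__(self, value):\n",
        "    super().__init__()\n",
        "    self.factor = tf.Variable(value)\n",
        "\n",
        "scale = Scale(3.0)\n",
        "\n",
        "# Write the current variable values to a checkpoint in a temporary directory.\n",
        "checkpoint = tf.train.Checkpoint(module=scale)\n",
        "ckpt_path = checkpoint.write(os.path.join(tempfile.mkdtemp(), 'ckpt'))\n",
        "\n",
        "# Overwrite the variable, then read the saved value back from the checkpoint.\n",
        "scale.factor.assign(0.0)\n",
        "checkpoint.read(ckpt_path)\n",
        "print(scale.factor.numpy())"
      ]
    },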
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nxU6P1RGwHyC"
      },
      "source": [
        "The `tf.keras.layers.Layer` and `tf.keras.Model` classes build on `tf.Module`, providing additional functionality and convenience methods for building, training, and saving models. Some of these are demonstrated in the next section."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tQzt3yaWMzLf"
      },
      "source": [
        "Refer to [Intro to modules](intro_to_modules.ipynb) for details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rk1IEG5aav7X"
      },
      "source": [
        "## Training loops\n",
        "\n",
        "Now put this all together to build a basic model and train it from scratch.\n",
        "\n",
        "First, create some example data. This generates a cloud of points that loosely follows a quadratic curve:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VcuFr7KPRPzn"
      },
      "outputs": [],
      "source": [
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "matplotlib.rcParams['figure.figsize'] = [9, 6]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sXN9E_xf-GiP"
      },
      "outputs": [],
      "source": [
        "x = tf.linspace(-2, 2, 201)\n",
        "x = tf.cast(x, tf.float32)\n",
        "\n",
        "def f(x):\n",
        "  y = x**2 + 2*x - 5\n",
        "  return y\n",
        "\n",
        "y = f(x) + tf.random.normal(shape=[201])\n",
        "\n",
        "plt.plot(x.numpy(), y.numpy(), '.', label='Data')\n",
        "plt.plot(x, f(x), label='Ground truth')\n",
        "plt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "De5LldboSWcW"
      },
      "source": [
        "Create a quadratic model with randomly initialized weights and a bias:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pypd0GB4SRhf"
      },
      "outputs": [],
      "source": [
        "class Model(tf.Module):\n",
        "\n",
        "  def __init__(self):\n",
        "    # Randomly generate weight and bias terms\n",
        "    rand_init = tf.random.uniform(shape=[3], minval=0., maxval=5., seed=22)\n",
        "    # Initialize model parameters\n",
        "    self.w_q = tf.Variable(rand_init[0])\n",
        "    self.w_l = tf.Variable(rand_init[1])\n",
        "    self.b = tf.Variable(rand_init[2])\n",
        "  \n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    # Quadratic Model : quadratic_weight * x^2 + linear_weight * x + bias\n",
        "    return self.w_q * (x**2) + self.w_l * x + self.b"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36o7VjaesScg"
      },
      "source": [
        "First, observe your model's performance before training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GkwToC5BWV1c"
      },
      "outputs": [],
      "source": [
        "quad_model = Model()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ReWhH40wTY5F"
      },
      "outputs": [],
      "source": [
        "def plot_preds(x, y, f, model, title):\n",
        "  plt.figure()\n",
        "  plt.plot(x, y, '.', label='Data')\n",
        "  plt.plot(x, f(x), label='Ground truth')\n",
        "  plt.plot(x, model(x), label='Predictions')\n",
        "  plt.title(title)\n",
        "  plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y0JtXQat-nlk"
      },
      "outputs": [],
      "source": [
        "plot_preds(x, y, f, quad_model, 'Before training')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hLzwD0-ascGf"
      },
      "source": [
        "Now, define a loss for your model:\n",
        "\n",
        "Given that this model is intended to predict continuous values, the mean squared error (MSE) is a good choice for the loss function. Given a vector of predictions, $\\hat{y}$, and a vector of true targets, $y$, the MSE is defined as the mean of the squared differences between the predicted values and the ground truth.\n",
        "\n",
        "$MSE = \\frac{1}{m}\\sum_{i=1}^{m}(\\hat{y}_i -y_i)^2$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eCtJ1uuCseZd"
      },
      "outputs": [],
      "source": [
        "def mse_loss(y_pred, y):\n",
        "  return tf.reduce_mean(tf.square(y_pred - y))"
      ]
    },
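    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check the formula by hand, the sketch below evaluates the same mean-of-squared-differences expression on three toy values. The tensors `toy_pred` and `toy_true` are made-up numbers chosen only so the arithmetic is easy to follow:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical toy values, not part of the dataset used in this guide.\n",
        "toy_pred = tf.constant([1., 2., 3.])\n",
        "toy_true = tf.constant([1., 2., 5.])\n",
        "\n",
        "# The squared differences are [0, 0, 4], so their mean is 4/3.\n",
        "toy_mse = tf.reduce_mean(tf.square(toy_pred - toy_true))\n",
        "print(toy_mse.numpy())"
      ]
    },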
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7EWyDu3zot2w"
      },
      "source": [
        "Write a basic training loop for the model. The loop uses the MSE loss function and its gradients with respect to the model's parameters to update those parameters iteratively. Training on mini-batches is memory efficient and often converges faster than training on the full dataset at once. The `tf.data.Dataset` API has useful functions for batching and shuffling."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8kX_-zily2Ia"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "dataset = tf.data.Dataset.from_tensor_slices((x, y))\n",
        "dataset = dataset.shuffle(buffer_size=x.shape[0]).batch(batch_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nOaES5gyTDtG"
      },
      "outputs": [],
      "source": [
        "# Set training parameters\n",
        "epochs = 100\n",
        "learning_rate = 0.01\n",
        "losses = []\n",
        "\n",
        "# Run the training loop\n",
        "for epoch in range(epochs):\n",
        "  for x_batch, y_batch in dataset:\n",
        "    with tf.GradientTape() as tape:\n",
        "      batch_loss = mse_loss(quad_model(x_batch), y_batch)\n",
        "    # Take one gradient descent step on each parameter\n",
        "    grads = tape.gradient(batch_loss, quad_model.variables)\n",
        "    for g, v in zip(grads, quad_model.variables):\n",
        "      v.assign_sub(learning_rate * g)\n",
        "  # Keep track of model loss per epoch\n",
        "  loss = mse_loss(quad_model(x), y)\n",
        "  losses.append(loss)\n",
        "  if epoch % 10 == 0:\n",
        "    print(f'Mean squared error for epoch {epoch}: {loss.numpy():0.3f}')\n",
        "\n",
        "# Plot model results\n",
        "print(\"\\n\")\n",
        "plt.plot(range(epochs), losses)\n",
        "plt.xlabel(\"Epoch\")\n",
        "plt.ylabel(\"Mean Squared Error (MSE)\")\n",
        "plt.title('MSE loss vs training epochs');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dW5B2TTRsvxE"
      },
      "source": [
        "Now, observe your model's performance after training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qcvzyg3eYLh8"
      },
      "outputs": [],
      "source": [
        "plot_preds(x, y, f, quad_model, 'After training')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hbtmFJIXb6qm"
      },
      "source": [
        "That's working, but remember that implementations of common training utilities are available in the `tf.keras` module. So, consider using those before writing your own. To start with, the `Model.compile` and `Model.fit` methods implement a training loop for you:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cjx23MiztFmT"
      },
      "source": [
        "Begin by creating a Sequential model in Keras using `tf.keras.Sequential`. One of the simplest Keras layers is the dense layer, which can be instantiated with `tf.keras.layers.Dense`. The dense layer can learn multidimensional linear relationships of the form $\\mathrm{Y} = \\mathrm{W}\\mathrm{X} +  \\vec{b}$. To learn a nonlinear equation of the form $w_1x^2 + w_2x + b$, the dense layer's input should be a data matrix with $x^2$ and $x$ as features. The lambda layer, `tf.keras.layers.Lambda`, can be used to perform this stacking transformation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5rt8HP2TZhEM"
      },
      "outputs": [],
      "source": [
        "new_model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Lambda(lambda x: tf.stack([x, x**2], axis=1)),\n",
        "    tf.keras.layers.Dense(units=1, kernel_initializer=tf.random.normal)])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "73kCo1BtP3rQ"
      },
      "outputs": [],
      "source": [
        "new_model.compile(\n",
        "    loss=tf.keras.losses.MSE,\n",
        "    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))\n",
        "\n",
        "history = new_model.fit(x, y,\n",
        "                        epochs=100,\n",
        "                        batch_size=32,\n",
        "                        verbose=0)\n",
        "\n",
        "# Recent Keras versions require a `.keras` (or `.h5`) file extension here.\n",
        "new_model.save('./my_new_model.keras')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u3q5d1SzvzTq"
      },
      "source": [
        "Observe your Keras model's performance after training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mo7zRV7XZjv7"
      },
      "outputs": [],
      "source": [
        "plt.plot(history.history['loss'])\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylim([0, max(plt.ylim())])\n",
        "plt.ylabel('Loss [Mean Squared Error]')\n",
        "plt.title('Keras training progress');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bB44a9YsvnfK"
      },
      "outputs": [],
      "source": [
        "plot_preds(x, y, f, new_model, 'After Training: Keras')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ng-BY_eGS0bn"
      },
      "source": [
        "Refer to [Basic training loops](basic_training_loops.ipynb) and the [Keras guide](https://www.tensorflow.org/guide/keras) for more details."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "basics.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
